import ast
import json
import os
import warnings
from typing import Any, Dict, List

import openai

from pipelines.prompta.utils import query2str
from prompta.core.language import BaseLanguage
from .probabilistic_abstract_oracle import ProbabilisticAbstractOracle


class BaseLLMOracle(ProbabilisticAbstractOracle):
    """Base class for probabilistic oracles that answer queries with an LLM.

    Requests go through an OpenAI chat-completions client in JSON mode (see
    ``_get_openai_resp``). Responses are memoised in ``llm_resp_cache`` and can
    be persisted to / restored from ``<exp_dir>/<language.context_name>.json``.
    Subclasses provide ``_construct_existence_message``.
    """

    # Model identifiers accepted by ``__init__``. Only names containing "gpt"
    # are wired to the OpenAI client by this class; other entries (e.g.
    # 'Llama2') get no client here.
    SUPPORTED_MODELS = (
        'gpt-4',
        'gpt-4-turbo',
        'gpt-4-turbo-preview',
        'gpt-3.5-turbo',
        'gpt-3.5-turbo-0125',
        'Llama2',
    )

    def __init__(
            self,
            language: BaseLanguage,
            model_name: str,
            *args: Any,
            **kwargs: Any
            ) -> None:
        """Initialise the oracle and attach the shared OpenAI client.

        Args:
            language: Language forwarded to the base-class constructor.
            model_name: One of ``SUPPORTED_MODELS``.
            *args, **kwargs: Forwarded unchanged to the base-class constructor.

        Raises:
            ImportError: If the shared OpenAI client cannot be imported
                (the underlying error is chained as ``__cause__``).
            ValueError: If ``model_name`` is not in ``SUPPORTED_MODELS``.
        """
        super().__init__(language, *args, **kwargs)
        try:
            from prompta.utils.set_api import global_openai_client
        except Exception as err:
            # Report any failure while importing the shared client uniformly
            # as ImportError, but keep the real reason as the cause.
            raise ImportError("Cannot set OpenAI API key!") from err

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if model_name not in self.SUPPORTED_MODELS:
            raise ValueError(
                f"Unsupported model_name {model_name!r}; "
                f"expected one of {self.SUPPORTED_MODELS}"
            )
        # Record the name for every supported model, not only GPT ones.
        self.model_name = model_name
        if 'gpt' in model_name.lower():
            self.client = global_openai_client
        # Maps a query key to the stored LLM response for it.
        self.llm_resp_cache = {}

    def reset(self, language: BaseLanguage, exp_dir, alphabet=None, load_history=False):
        """Reset the oracle for a new run and (re)initialise the response cache.

        Delegates to the base-class ``reset``. Then it either reloads cached LLM
        responses from disk (``load_history=True``) or starts with an empty
        cache.
        """
        super().reset(language, exp_dir, alphabet=alphabet, load_history=load_history)
        if load_history:
            self.load_llm_resp_cache()
        else:
            self.llm_resp_cache = {}

    def _get_json_resp(self, queries, seed: int=0, retry: int=10, max_tokens=512):
        """Query the model until it returns parseable JSON.

        Args:
            queries: Chat messages forwarded unchanged to ``_get_openai_resp``.
            seed: Seed for the first attempt; attempt ``i`` uses ``seed + i``.
            retry: Maximum number of attempts.
            max_tokens: Token budget for the first attempt. It grows by 100
                after every unparseable reply, presumably because a truncated
                completion is a common cause of invalid JSON.

        Returns:
            The decoded JSON value, or ``None`` if every attempt failed.
        """
        for attempt in range(retry):
            resp = self._get_openai_resp(queries, seed=seed + attempt, max_tokens=max_tokens)
            try:
                return json.loads(resp)
            except (json.JSONDecodeError, TypeError):
                # TypeError: the reply carried no text (content is None); treat
                # it like malformed JSON and retry.
                print(resp)
                max_tokens += 100
        return None

    def _construct_existence_message(self, query: str):
        """Subclass hook that builds the chat messages for ``query``; not implemented here."""
        raise NotImplementedError

    def _get_openai_resp(self, queries: List[Dict], max_tokens=50, n=1, seed=0):
        """Send ``queries`` to the chat-completions endpoint in JSON mode.

        Args:
            queries: Chat messages passed as ``messages`` to the API.
            max_tokens: Maximum number of tokens in the completion.
            n: Number of completions requested; only the first is returned.
            seed: Sampling seed forwarded to the API.

        Returns:
            The text content of the first choice. The OpenAI SDK types this as
            optional, so it may be ``None``.
        """
        response = self.client.chat.completions.create(
            model=self.model_name,
            response_format={"type": "json_object"},  # JSON mode
            messages=queries,
            max_tokens=max_tokens,
            n=n,
            stop=None,
            temperature=0.2,  # low randomness for more consistent answers
            seed=seed,
        )
        return response.choices[0].message.content

    def _llm_resp_cache_path(self):
        """Return the path of the JSON file that persists ``llm_resp_cache``."""
        # NOTE(review): assumes self.exp_dir is set by the base-class reset().
        return os.path.join(self.exp_dir, f"{self.language.context_name}.json")

    def save_llm_resp_cache(self):
        """Write ``llm_resp_cache`` to ``<exp_dir>/<context_name>.json``.

        JSON object keys must be strings. Each cache key (an iterable,
        presumably a tuple of query symbols; confirm with callers) is stored as
        the Python literal of a tuple of its elements converted with ``str``.
        ``load_llm_resp_cache`` parses that literal back.
        """
        # makedirs also creates missing parents and tolerates an existing dir.
        os.makedirs(self.exp_dir, exist_ok=True)
        # repr() quotes/escapes every element correctly, so symbols containing
        # quotes or backslashes round-trip instead of producing a bad literal.
        json_format_dict = {
            repr(tuple(str(symbol) for symbol in key)): value
            for key, value in self.llm_resp_cache.items()
        }
        with open(self._llm_resp_cache_path(), "w") as f:
            json.dump(json_format_dict, f, indent=4)

    def load_llm_resp_cache(self):
        """Restore ``llm_resp_cache`` from disk and replay answers into ``query_history``.

        A missing file yields an empty cache. An unreadable or malformed file
        emits a warning and also yields an empty cache (best effort: a bad
        cache file does not abort the run). Each loaded entry's ``'answer'``
        value is then copied into ``self.query_history`` under the same key.

        Raises:
            KeyError: If a loaded entry has no ``'answer'`` field.
        """
        self.llm_resp_cache = {}
        json_path = self._llm_resp_cache_path()
        if os.path.exists(json_path):
            try:
                with open(json_path, "r") as f:
                    json_format_dict = json.load(f)
                if not isinstance(json_format_dict, dict):
                    raise ValueError("top-level JSON value is not an object")
                # literal_eval accepts only Python literals, so a tampered cache
                # file cannot execute code the way eval() could.
                self.llm_resp_cache = {
                    ast.literal_eval(k): v for k, v in json_format_dict.items()
                }
            except (OSError, ValueError, SyntaxError, TypeError) as err:
                warnings.warn(
                    f"Ignoring unreadable LLM response cache {json_path!r}: {err}"
                )
                self.llm_resp_cache = {}

        # NOTE(review): assumes self.query_history is a dict-like attribute
        # provided by the base class.
        for k, v in self.llm_resp_cache.items():
            self.query_history[k] = v['answer']
